Skip to content

用二维 tensor 举例最直观。

假设 dec_out_list 里有 3 个二维 tensor:

python
y0 = torch.tensor([
    [1, 2],
    [3, 4],
])

y1 = torch.tensor([
    [10, 20],
    [30, 40],
])

y2 = torch.tensor([
    [100, 200],
    [300, 400],
])

每个 shape 都是:

python
[2, 2]

现在执行:

python
torch.stack([y0, y1, y2], dim=-1)

因为 dim=-1 是在最后新增一个维度,所以结果 shape 变成:

python
[2, 2, 3]

结果可以理解为:

python
stacked = [
    [
        [1, 10, 100],
        [2, 20, 200],
    ],
    [
        [3, 30, 300],
        [4, 40, 400],
    ],
]

也就是说,原来同一个位置的值被放到新维度里:

python
stacked[0, 0, :] = [y0[0,0], y1[0,0], y2[0,0]]
                 = [1, 10, 100]

stacked[0, 1, :] = [y0[0,1], y1[0,1], y2[0,1]]
                 = [2, 20, 200]

然后执行:

python
stacked.sum(-1)

就是对最后一维求和:

python
[
    [1 + 10 + 100,   2 + 20 + 200],
    [3 + 30 + 300,   4 + 40 + 400],
]

结果是:

python
tensor([
    [111, 222],
    [333, 444],
])

所以:

python
torch.stack(dec_out_list, dim=-1).sum(-1)

等价于:

python
y0 + y1 + y2

只是 stack + sum 可以处理任意数量的尺度预测。

*记录并在线阅读我的笔记*